import os
import numpy
import jieba

import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

from tqdm.auto import tqdm
import pandas as pd

from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from log import Logger
from bert import Bert
from lstm import encode_label
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

# Pin CUDA work to the second GPU when the machine actually has one.
# Bug fix: the unguarded torch.cuda.set_device(1) raised at import time on
# CPU-only or single-GPU machines.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    torch.cuda.set_device(1)
logger = Logger(filename="./mlp.log").get_logger()


def load_data(root, data_type) -> tuple:
    """
    Load raw labeled data and encode the emotion labels.
        Example: content, labels = load_data(root="./haodf", data_type="train")
        will load raw data from "./haodf/two_health_train_labeled.csv"
    :param root: data dir
    :param data_type: one of ("train", "eval", "test")
    :return: (text column as pd.Series of str, encoded label indices)
    """
    pth = os.path.join(root, "two_health_" + data_type + "_labeled.csv")
    _df = pd.read_csv(pth, encoding='utf-8')
    # '文本' = text column, '情绪标签' = emotion-label column
    _content = _df['文本'].astype(str)
    # the fitted encoder is not needed by callers; keep only the indices
    label_index, _ = encode_label(_df['情绪标签'].values)

    return _content, label_index


def collate_fn(examples):
    """Batch a list of (embedding, label) pairs into model-ready tensors.

    :param examples: iterable of (embedding, int_label) pairs
    :return: (float input tensor, long target tensor)
    """
    embeddings = []
    labels = []
    for embedding, label in examples:
        embeddings.append(embedding)
        labels.append(label)
    # go through numpy.array first so a list of arrays stacks efficiently
    inputs = torch.tensor(numpy.array(embeddings))
    targets = torch.tensor(labels, dtype=torch.long)
    return inputs, targets


class MLPDataset(Dataset):
    """Dataset of pre-computed BERT sentence embeddings with integer labels."""

    def __init__(self, root, data_type):
        texts, labels = load_data(root=root, data_type=data_type)
        self.labels = labels
        self.bert = Bert(using_gpu=True)
        # embed every text up front so __getitem__ is a cheap lookup
        self.contents = [self.bert.embedding(text) for text in texts]

    def __len__(self):
        return len(self.contents)

    def __getitem__(self, i):
        return self.contents[i], self.labels[i]


class MLPModel(nn.Module):
    """Two-layer MLP classifier over BERT sentence embeddings.

    :param in_dim: input embedding size
    :param hid_dim: hidden layer size
    :param num_class: number of output classes
    :param dropout: dropout probability applied after the hidden layer
    """

    def __init__(self, in_dim, hid_dim, num_class, dropout=0.6):
        super(MLPModel, self).__init__()
        self.fc1 = nn.Linear(in_dim, hid_dim)
        self.fc2 = nn.Linear(hid_dim, num_class)
        self.relu = nn.ReLU()
        # Bug fix: the dropout probability was previously ignored
        # (nn.Dropout() silently defaulted to 0.5) and the layer was
        # never applied in forward().
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return (log-probabilities, probabilities), each of shape
        (..., num_class)."""
        x = self.dropout(self.relu(self.fc1(x)))
        # Bug fix: the output layer previously went through ReLU before
        # softmax, clipping all negative logits to zero; logits must be
        # left unbounded.
        outputs = self.fc2(x)
        log_probs = F.log_softmax(outputs, dim=-1)
        probs = F.softmax(outputs, dim=-1)
        return log_probs, probs


def train(cfg):
    """Train the BERT-embedding MLP classifier and checkpoint it.

    Trains for cfg["epochs"] epochs on the "train" split, evaluates on the
    "eval" split every 5 epochs, and saves the whole model to
    "bert_mlp_model.pth".
    :param cfg: dict with keys "using_gpu", "epochs", "embedding_dim",
                "hidden_dim", "num_class", "batch_size"
    """
    # setting device
    device = "cpu"
    if cfg["using_gpu"] >= 0 and torch.cuda.is_available():
        device = "cuda"
    device = torch.device(device)

    # prepare dataset
    train_dataset = MLPDataset(root="./haodf", data_type="train")
    train_data_loader = DataLoader(train_dataset,
                                   batch_size=cfg["batch_size"],
                                   collate_fn=collate_fn,
                                   shuffle=True)

    eval_dataset = MLPDataset(root="./haodf", data_type="eval")
    eval_data_loader = DataLoader(eval_dataset,
                                  batch_size=1,
                                  collate_fn=collate_fn,
                                  shuffle=False)

    # prepare MLP model and optim
    model = MLPModel(in_dim=cfg["embedding_dim"], hid_dim=cfg["hidden_dim"], num_class=cfg["num_class"])
    model.to(device)

    nll_loss = nn.NLLLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    for epoch in range(cfg["epochs"]):
        total_loss = 0
        model.train()
        for batch in tqdm(train_data_loader, desc=f"Training Epoch {epoch}"):
            inputs, targets = [x.to(device) for x in batch]
            # Bug fix: forward() returns (log_probs, probs); the tuple was
            # previously passed straight to NLLLoss, which raises at runtime.
            log_probs, _ = model(inputs)
            loss = nll_loss(log_probs, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Loss: {total_loss:.2f}")
        logger.info(f"Loss: {total_loss:.2f}")

        if (epoch + 1) % 5 == 0:
            model.eval()
            acc = 0
            for batch in tqdm(eval_data_loader, desc=f"Testing"):
                inputs, targets = [x.to(device) for x in batch]
                with torch.no_grad():
                    # Bug fix: unpack the (log_probs, probs) tuple before argmax.
                    log_probs, _ = model(inputs)
                    acc += (log_probs.argmax(dim=1) == targets).sum().item()
            # eval batch_size is 1, but divide by the dataset length so the
            # accuracy stays correct if the batch size ever changes
            print(f"Acc: {acc / len(eval_dataset):.2f}")
            logger.info(f"Acc: {acc / len(eval_dataset):.2f}")

    torch.save(model, "bert_mlp_model.pth")


def predict(cfg):
    """Evaluate the saved MLP model on the eval split.

    Loads "bert_mlp_model.pth", reports accuracy, shows a confusion matrix,
    and writes per-sample predictions to "./bert_predict.csv".
    :param cfg: dict with at least the "using_gpu" key
    """
    # choose device
    device = torch.device("cuda" if cfg["using_gpu"] >= 0 and torch.cuda.is_available() else "cpu")

    test_dataset = MLPDataset(root="./haodf", data_type="eval")
    test_data_loader = DataLoader(test_dataset,
                                  batch_size=1,
                                  collate_fn=collate_fn,
                                  shuffle=False)

    model = torch.load("bert_mlp_model.pth", torch.device(device))
    model.eval()

    # collected per-sample results
    all_targets = []
    all_output = []
    all_predicts = []

    acc = 0
    for batch in tqdm(test_data_loader, desc=f"Testing"):
        inputs, targets = [x.to(device) for x in batch]
        with torch.no_grad():
            log_probs, probs = model(inputs)
        predictions = log_probs.argmax(dim=1)
        acc += (predictions == targets).sum().item()
        all_targets.extend(targets.cpu().numpy().tolist())
        all_predicts.extend(predictions.cpu().numpy().tolist())
        all_output.extend(probs.cpu().numpy().tolist())
    print(f"Acc: {acc / len(test_data_loader):.2f}")

    cm = confusion_matrix(all_targets, all_predicts)
    ConfusionMatrixDisplay(cm).plot()
    plt.show()

    # save results next to the predictions for inspection
    pth = os.path.join("./haodf", "two_health_eval_labeled.csv")
    _df = pd.read_csv(pth, encoding='utf-8')
    _df["predict"] = all_predicts
    _df["target"] = all_targets
    _df["detail"] = all_output
    _df.to_csv("./bert_predict.csv")
    print("保存成功")


if __name__ == '__main__':
    # hyper-parameters shared by train() and predict()
    args = dict(
        using_gpu=True,
        epochs=20,
        embedding_dim=768,
        hidden_dim=256,
        num_class=2,
        batch_size=128,
    )
    # train(args)
    predict(args)
